import torch
import torch.nn as nn
from torch.autograd import Variable
from torch.utils.data import DataLoader, TensorDataset, random_split, RandomSampler
import numpy as np
import copy
import random

# Seed constants. The names suggest a generic/Python seed, a NumPy seed and a
# PyTorch seed; no live code in this section reads them — verify usage elsewhere.
seed, npseed, torchseed = 2, 6, 5

def _beta_weighted_loss(model, loader, beta, criterion, device):
    """Return ``sum_j beta[j] * loss_j`` over the single-sample batches of *loader*.

    Each per-sample loss is scaled by its own entry of ``beta`` so that a later
    derivative w.r.t. ``beta[j]`` isolates the contribution of sample ``j``.
    """
    beta_loss = 0
    for idx, (image, label) in enumerate(loader):
        image = image.to(device)
        label = label.to(device)
        # NOTE(review): ``[0]`` assumes the model returns an indexable object whose
        # first element holds the logits expected by ``criterion`` — confirm.
        output = model(image)[0].to(torch.float32)
        beta_loss = beta_loss + criterion(output, label) * beta[idx]
    return beta_loss


def _per_sample_grads_and_hessian(beta_loss, beta, params, d):
    """Differentiate ``beta_loss`` twice through ``params`` and ``beta``.

    Returns ``(grads, hessian)``:
      * ``grads`` has shape ``(len(beta), d)``. Row ``j`` is the flattened
        gradient of sample ``j``'s loss, read off as
        ``d^2 beta_loss / (d theta_i d beta_j)``.
      * ``hessian`` has shape ``(d, d)`` and is the Hessian of ``beta_loss``
        w.r.t. the flattened parameters.
    Columns beyond the real parameter count stay zero.

    Raises:
        ValueError: if ``params`` hold more than ``d`` scalars.
    """
    first = torch.autograd.grad(beta_loss, params, create_graph=True)
    flat_grad = torch.cat([g.reshape(-1) for g in first])
    n_params = flat_grad.numel()
    if n_params > d:
        raise ValueError(
            "d={} is smaller than the number of model parameters ({})".format(d, n_params))

    grads = torch.zeros(beta.numel(), d, device=beta.device)
    hessian = torch.zeros(d, d, device=beta.device)
    wrt = [beta, *params]
    for i, g_i in enumerate(flat_grad):
        # One backward pass yields both d(g_i)/d(beta), which is column i of the
        # per-sample gradient matrix, and d(g_i)/d(theta), which is row i of the
        # Hessian. retain_graph keeps the first-order graph alive for the next i;
        # no create_graph is needed because nothing here is differentiated again.
        second = torch.autograd.grad(g_i, wrt, retain_graph=True, allow_unused=True)
        if second[0] is not None:
            grads[:, i] = second[0]
        hessian[i, :n_params] = torch.cat([
            torch.zeros_like(p).reshape(-1) if h is None else h.reshape(-1)
            for h, p in zip(second[1:], params)
        ])
    return grads, hessian


def LossScaledTrace(test_model, train_data, d, train_size, B=3200):
    """Estimate ``tr(H @ Sigma) / (2 * L)`` from ``B`` randomly drawn training samples.

    Samples are drawn without replacement and passed through the model one at a
    time. Every statistic below is computed over those ``B`` samples:
      * L is the mean loss,
      * H is the Hessian of the mean loss w.r.t. the flattened parameters,
      * Sigma = mean(g g^T) - mean(g) mean(g)^T is the covariance of the
        per-sample gradients g.

    Cost: one backward pass per parameter, plus two ``(d, d)`` matrices.
    This is only practical for small models.

    Args:
        test_model: model to evaluate. It is moved to ``cuda:0`` when available,
            otherwise to the CPU.
        train_data: dataset yielding ``(image, label)`` pairs.
        d: number of scalar parameters in ``test_model`` (the side of H and Sigma).
        train_size: size of ``train_data``. ``B`` may not exceed it.
        B: number of samples to draw.

    Returns:
        Tuple ``(tr(H Sigma) / max(2 L, 1e-14), ||H||_F, tr(H))`` of NumPy scalars.

    Raises:
        ValueError: if ``B`` is not in ``[1, train_size]``, or if ``d`` is smaller
            than the model's parameter count.
    """
    if B > train_size:
        raise ValueError("Error: The number of samples cannot be larger than the train set size!")
    if B < 1:
        raise ValueError("B must be at least 1, got {}".format(B))

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model = test_model.to(device)
    criterion = torch.nn.CrossEntropyLoss()

    sampler = RandomSampler(train_data, replacement=False, num_samples=B)
    single_loader = DataLoader(dataset=train_data, batch_size=1, sampler=sampler)

    # All-ones weights make beta_loss equal to the summed loss, while the
    # derivative w.r.t. beta[j] still isolates sample j. The tensor is created on
    # the device directly so that it is a float32 leaf.
    beta = torch.ones(B, device=device, requires_grad=True)

    # Clear stale .grad buffers. The derivatives below use torch.autograd.grad
    # and never accumulate into .grad.
    model.zero_grad()
    beta_loss = _beta_weighted_loss(model, single_loader, beta, criterion, device)
    gradients, hessian = _per_sample_grads_and_hessian(
        beta_loss, beta, list(model.parameters()), d)

    full_loss = beta_loss.item() / B
    full_gradient = gradients.mean(dim=0).reshape(d, 1)
    hessian = hessian / B
    covariance = gradients.T @ gradients / B - full_gradient @ full_gradient.T

    # 10e-15 (= 1e-14) guards against dividing by a zero loss.
    denom = max(2 * full_loss, 10e-15)
    return (
        torch.trace(hessian @ covariance).cpu().numpy() / denom,
        torch.norm(hessian, p='fro').cpu().numpy()[()],
        torch.trace(hessian).cpu().numpy()[()],
    )
